{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vwELCooy4ljr"
      },
      "source": [
        "# Extractive QA to build structured data\n",
        "\n",
        "Traditional ETL/data parsing systems establish rules to extract information of interest. Regular expressions, string parsing and similar methods define fixed rules. This works in many cases but what if you are working with unstructured data containing numerous variations? The rules can be cumbersome and hard to maintain over time.\n",
        "\n",
        "This notebook uses machine learning and extractive question-answering (QA) to utilize the vast knowledge built into large language models. These models have been trained on extremely large datasets, learning the many variations of natural language. "
      ]
    },
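    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the fixed-rule approach, the sketch below applies a single date regex to a few sample strings. `DATE_PATTERN` is a hypothetical rule written for this example and is not part of the original notebook. It matches the numeric date but misses the written-out variations, which is the maintenance problem described above."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import re\n",
        "\n",
        "# Hypothetical rule: numeric dates such as 6/03/2021\n",
        "DATE_PATTERN = re.compile(r\"\\b\\d{1,2}/\\d{1,2}/\\d{4}\\b\")\n",
        "\n",
        "rows = [\n",
        "    \"Released on 6/03/2021\",\n",
        "    \"Release delayed until the 11th of August\",\n",
        "    \"Great day: closing price for March 23rd is $33.11\",\n",
        "]\n",
        "\n",
        "# Print the matched date (or None) next to each row\n",
        "for row in rows:\n",
        "    match = DATE_PATTERN.search(row)\n",
        "    print(match.group(0) if match else None, \"|\", row)"
      ],
      "execution_count": null,
      "outputs": []
    },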
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ew7orE2O441o"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LPQTb25tASIG"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_YnqorRKAbLu"
      },
      "source": [
        "# Train a QA model with few-shot learning\n",
        "\n",
        "The code below trains a new QA model using a few examples. These examples gives the model hints on the type of questions that will be asked and the type of answers to look for. It doesn't take a lot of examples to do this as shown below."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OUc9gqTyAYnm",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 75
        },
        "outputId": "7e7f93c7-5ad5-46a6-d04d-c7450d246f6c"
      },
      "source": [
        "import pandas as pd\n",
        "from txtai.pipeline import HFTrainer, Questions, Labels\n",
        "\n",
        "# Training data for few-shot learning\n",
        "data = [\n",
        "    {\"question\": \"What is the url?\",\n",
        "     \"context\": \"Faiss (https://github.com/facebookresearch/faiss) is a library for efficient similarity search.\",\n",
        "     \"answers\": \"https://github.com/facebookresearch/faiss\"},\n",
        "    {\"question\": \"What is the url\", \"context\": \"The last release was Wed Sept 25 2021\", \"answers\": None},\n",
        "    {\"question\": \"What is the date?\", \"context\": \"The last release was Wed Sept 25 2021\", \"answers\": \"Wed Sept 25 2021\"},\n",
        "    {\"question\": \"What is the date?\", \"context\": \"The order total comes to $44.33\", \"answers\": None},\n",
        "    {\"question\": \"What is the amount?\", \"context\": \"The order total comes to $44.33\", \"answers\": \"$44.33\"},\n",
        "    {\"question\": \"What is the amount?\", \"context\": \"The last release was Wed Sept 25 2021\", \"answers\": None},\n",
        "]\n",
        "\n",
        "# Fine-tune QA model\n",
        "trainer = HFTrainer()\n",
        "model, tokenizer = trainer(\"distilbert-base-cased-distilled-squad\", data, task=\"question-answering\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='3' max='3' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [3/3 00:03, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {}
        }
      ]
    },
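    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "An optional sanity check, sketched here and not part of the original notebook: load the fine-tuned model into a `Questions` pipeline and ask one of the training questions against a context that contains a date and one that does not. Based on the training examples above, the expectation is an answer span for the first context and `None` for the second, although results from such a small training set can vary. The name `check` is illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Optional check: query the fine-tuned model directly\n",
        "check = Questions(path=(model, tokenizer))\n",
        "\n",
        "# Same question against a context with a date and one without\n",
        "print(check([\"What is the date?\"] * 2,\n",
        "            [\"The last release was Wed Sept 25 2021\", \"The order total comes to $44.33\"]))"
      ],
      "execution_count": null,
      "outputs": []
    },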
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1hzPnmrTaUjH"
      },
      "source": [
        "# Parse data into a structured table\n",
        "\n",
        "The next section takes a series of rows of text and runs a set of questions against each row. The answers are then used to build a pandas DataFrame."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4X5z3UjnAGe7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "outputId": "41015023-c7b7-4515-ad30-f1b2d8143ea2"
      },
      "source": [
        "# Input data\n",
        "context = [\"Released on 6/03/2021\",\n",
        "           \"Release delayed until the 11th of August\",\n",
        "           \"Documentation can be found here: neuml.github.io/txtai\",\n",
        "           \"The stock price fell to three dollars\",\n",
        "           \"Great day: closing price for March 23rd is $33.11, for details - https://finance.google.com\"]\n",
        "\n",
        "# Define column queries\n",
        "queries = [\"What is the url?\", \"What is the date?\", \"What is the amount?\"]\n",
        "\n",
        "# Extract fields\n",
        "questions = Questions(path=(model, tokenizer), gpu=True)\n",
        "results = [questions([question] * len(context), context) for question in queries]\n",
        "results.append(context)\n",
        "\n",
        "# Load into DataFrame\n",
        "pd.DataFrame(list(zip(*results)), columns=[\"URL\", \"Date\", \"Amount\", \"Text\"])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-32471005-00c7-4cc3-8624-4bd185b48d76\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>URL</th>\n",
              "      <th>Date</th>\n",
              "      <th>Amount</th>\n",
              "      <th>Text</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>None</td>\n",
              "      <td>6/03/2021</td>\n",
              "      <td>None</td>\n",
              "      <td>Released on 6/03/2021</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>None</td>\n",
              "      <td>11th of August</td>\n",
              "      <td>None</td>\n",
              "      <td>Release delayed until the 11th of August</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>neuml.github.io/txtai</td>\n",
              "      <td>None</td>\n",
              "      <td>None</td>\n",
              "      <td>Documentation can be found here: neuml.github....</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>None</td>\n",
              "      <td>None</td>\n",
              "      <td>three dollars</td>\n",
              "      <td>The stock price fell to three dollars</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>https://finance.google.com</td>\n",
              "      <td>March 23rd</td>\n",
              "      <td>$33.11</td>\n",
              "      <td>Great day: closing price for March 23rd is $33...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-32471005-00c7-4cc3-8624-4bd185b48d76')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-32471005-00c7-4cc3-8624-4bd185b48d76 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-32471005-00c7-4cc3-8624-4bd185b48d76');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "                          URL  ...                                               Text\n",
              "0                        None  ...                              Released on 6/03/2021\n",
              "1                        None  ...           Release delayed until the 11th of August\n",
              "2       neuml.github.io/txtai  ...  Documentation can be found here: neuml.github....\n",
              "3                        None  ...              The stock price fell to three dollars\n",
              "4  https://finance.google.com  ...  Great day: closing price for March 23rd is $33...\n",
              "\n",
              "[5 rows x 4 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
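    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `zip(*results)` call above transposes per-column answer lists into per-row tuples. The sketch below shows the same step on small stand-in lists copied from the table output, so it runs without the model. The variable names are illustrative only."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Stand-in answers, one list per column (values copied from the output above)\n",
        "urls = [None, \"neuml.github.io/txtai\"]\n",
        "dates = [\"6/03/2021\", None]\n",
        "text = [\"Released on 6/03/2021\", \"Documentation can be found here: neuml.github.io/txtai\"]\n",
        "\n",
        "# zip(*columns) turns the column lists into row tuples\n",
        "columns = [urls, dates, text]\n",
        "rows = list(zip(*columns))\n",
        "print(rows[0])\n",
        "\n",
        "pd.DataFrame(rows, columns=[\"URL\", \"Date\", \"Text\"])"
      ],
      "execution_count": null,
      "outputs": []
    },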
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mY1Le-pve5yi"
      },
      "source": [
        "# Add additional columns\n",
        "\n",
        "This method can be combined with other models to categorize, group or otherwise derive additional columns. The code below derives an additional sentiment column."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 479
        },
        "id": "0kyJUcrKe43a",
        "outputId": "cea0d481-3315-4ec8-a7f9-8a9c4942db10"
      },
      "source": [
        "# Add sentiment\n",
        "labels = Labels(path=\"distilbert-base-uncased-finetuned-sst-2-english\", dynamic=False)\n",
        "labels = [\"POSITIVE\" if x[0][0] == 1 else \"NEGATIVE\" for x in labels(context)]\n",
        "results.insert(len(results) - 1, labels)\n",
        "\n",
        "# Load into DataFrame\n",
        "pd.DataFrame(list(zip(*results)), columns=[\"URL\", \"Date\", \"Amount\", \"Sentiment\", \"Text\"])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-a5d090e1-83da-4189-84b1-5d89675d9800\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>URL</th>\n",
              "      <th>Date</th>\n",
              "      <th>Amount</th>\n",
              "      <th>Sentiment</th>\n",
              "      <th>Text</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>None</td>\n",
              "      <td>6/03/2021</td>\n",
              "      <td>None</td>\n",
              "      <td>POSITIVE</td>\n",
              "      <td>Released on 6/03/2021</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>None</td>\n",
              "      <td>11th of August</td>\n",
              "      <td>None</td>\n",
              "      <td>NEGATIVE</td>\n",
              "      <td>Release delayed until the 11th of August</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>neuml.github.io/txtai</td>\n",
              "      <td>None</td>\n",
              "      <td>None</td>\n",
              "      <td>NEGATIVE</td>\n",
              "      <td>Documentation can be found here: neuml.github....</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>None</td>\n",
              "      <td>None</td>\n",
              "      <td>three dollars</td>\n",
              "      <td>NEGATIVE</td>\n",
              "      <td>The stock price fell to three dollars</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>https://finance.google.com</td>\n",
              "      <td>March 23rd</td>\n",
              "      <td>$33.11</td>\n",
              "      <td>POSITIVE</td>\n",
              "      <td>Great day: closing price for March 23rd is $33...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a5d090e1-83da-4189-84b1-5d89675d9800')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-a5d090e1-83da-4189-84b1-5d89675d9800 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-a5d090e1-83da-4189-84b1-5d89675d9800');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "                          URL  ...                                               Text\n",
              "0                        None  ...                              Released on 6/03/2021\n",
              "1                        None  ...           Release delayed until the 11th of August\n",
              "2       neuml.github.io/txtai  ...  Documentation can be found here: neuml.github....\n",
              "3                        None  ...              The stock price fell to three dollars\n",
              "4  https://finance.google.com  ...  Great day: closing price for March 23rd is $33...\n",
              "\n",
              "[5 rows x 5 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    }
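    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The indexing `x[0][0]` in the sentiment code above is easier to follow with the output shape in view. The sketch below assumes the `Labels` pipeline returns, for each input text, a list of (label id, score) pairs ordered from highest to lowest score, which is what that indexing implies. The `scores` list is made-up stand-in data for illustration, not real model output."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Stand-in for labels(context): one list of (label id, score) pairs per text,\n",
        "# best score first. The scores are made up for illustration.\n",
        "scores = [\n",
        "    [(1, 0.98), (0, 0.02)],\n",
        "    [(0, 0.91), (1, 0.09)],\n",
        "]\n",
        "\n",
        "# x[0] is the top-ranked pair and x[0][0] its label id (1 is treated as POSITIVE)\n",
        "print([\"POSITIVE\" if x[0][0] == 1 else \"NEGATIVE\" for x in scores])"
      ],
      "execution_count": null,
      "outputs": []
    }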
  ]
}